/*******************************************************************************
 * Copyright (C) 2018-2019 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 ******************************************************************************/

#include "identify.h"

#include "config.h"
#include "gstgvaidentify.h"
#include "gva_roi_meta.h"
#include "gva_tensor_meta.h"
#include "gva_utils.h"

#include <limits>
#include <opencv2/imgproc.hpp>
#include <vector>

#include "logger_functions.h"

#define UNUSED(x) (void)x

using namespace std::placeholders;

Identify::Identify(GstGvaIdentify *ovino) : frame_num(0) {
    masterGstElement_ = ovino;

    // Create gallery
    if (ovino->gallery) {
        double reid_threshold = ovino->threshold;
        gallery = std::unique_ptr<EmbeddingsGallery>(new EmbeddingsGallery(ovino->gallery, reid_threshold));
    }

    // Create tracker. TODO expose some of parameters as GST properties
    if (ovino->tracker) {
        TrackerParams tracker_reid_params;
        tracker_reid_params.min_track_duration = 1;
        tracker_reid_params.forget_delay = 150;
        tracker_reid_params.affinity_thr = 0.8;
        tracker_reid_params.averaging_window_size = 1;
        tracker_reid_params.bbox_heights_range = cv::Vec2f(10, 1080);
        tracker_reid_params.drop_forgotten_tracks = false;
        tracker_reid_params.max_num_objects_in_track = std::numeric_limits<int>::max();
        tracker_reid_params.objects_type = "face";
        tracker = std::unique_ptr<Tracker>(new Tracker(tracker_reid_params));
    }
}

// Nothing to do by hand: the gallery and tracker unique_ptr members release themselves.
Identify::~Identify() = default;

/**
 * Labels ROI tensors whose re-identification embedding matches an entry in the gallery.
 *
 * For each ROI in @p buffer, the first tensor produced by the re-identification model is selected.
 * If the element's "model" property is set, that is the first tensor whose model name contains it.
 * Otherwise it is the first tensor whose format is "cosine_distance".
 * Its float data is copied into an embedding, and all embeddings are matched against the gallery
 * in one call. Every tensor whose match is not EmbeddingsGallery::unknown_id receives
 * "label" (gallery label) and "label_id" (gallery id + 1, so recognized objects start at 1).
 *
 * Expects `gallery` to be non-null (ProcessOutput only calls this when it is).
 *
 * @param buffer buffer carrying ROI/tensor metadata; tensors are modified in place.
 */
void Identify::IdentifyObjects(GstBuffer *buffer) {
    GVA::RegionOfInterestList roi_list(buffer);
    std::vector<cv::Mat> embeddings;
    std::vector<GVA::Tensor *> tensors; // tensors[k] is the source tensor of embeddings[k]

    // Find embeddings generated by gvaclassify
    for (int i = 0; i < roi_list.NumberObjects(); i++) {
        GVA::RegionOfInterest &roi = roi_list[i];
        // Find tensor generated by reidentification model
        for (GVA::Tensor &tensor : roi) {
            if (masterGstElement_->model) {
                if (tensor.model_name().find(masterGstElement_->model) == std::string::npos)
                    continue;
            } else {
                if (tensor.format() != "cosine_distance")
                    continue;
            }
            // Wrap the data as a single-column matrix, then clone it so the embedding owns its memory
            // (`data` goes out of scope at the end of this iteration).
            std::vector<float> data = tensor.data<float>();
            cv::Mat blob_wrapper(static_cast<int>(data.size()), 1, CV_32F, data.data());
            embeddings.emplace_back(blob_wrapper.clone());
            tensors.push_back(&tensor);
            break; // only the first matching tensor per ROI is used
        }
    }

    // Compare embeddings against gallery
    auto ids = gallery->GetIDsByEmbeddings(embeddings);

    // ids[k] is the match for embeddings[k], i.e. for tensors[k]. Index both with the same k so an
    // unknown match does not shift later labels onto the wrong tensor.
    for (size_t k = 0; k < ids.size() && k < tensors.size(); k++) {
        if (ids[k] == EmbeddingsGallery::unknown_id)
            continue;
        tensors[k]->set_string("label", gallery->GetLabelByID(ids[k]));
        tensors[k]->set_int("label_id", ids[k] + 1); // recognized objects starting label_id=1
    }
}

/**
 * Feeds this frame's ROIs to the tracker and writes the resulting track ids into the ROI metadata.
 *
 * Each ROI becomes a TrackedObject built from its rectangle and confidence. The ROI's position is
 * stored as object_index, and the first tensor "label_id" field (if any) is used as the label.
 * After Tracker::Process, every returned detection with a valid object_index and a non-negative
 * object_id sets that ROI's meta id to object_id + 1 (tracked objects start at id 1).
 *
 * @param buffer buffer carrying ROI metadata; meta ids are modified in place.
 * @param info   video info; width/height give the frame size passed to the tracker.
 */
void Identify::TrackObjects(GstBuffer *buffer, GstVideoInfo *info) {
    GVA::RegionOfInterestList regions(buffer);
    std::vector<TrackedObject> detections;

    for (int idx = 0; idx < regions.NumberObjects(); ++idx) {
        GVA::RegionOfInterest &region = regions[idx];
        const GstVideoRegionOfInterestMeta *roi_meta = region.meta();

        TrackedObject detection(cv::Rect(roi_meta->x, roi_meta->y, roi_meta->w, roi_meta->h), region.confidence(),
                                -1, idx, -1);
        // Take the label from the first tensor that carries one.
        for (GVA::Tensor &tensor : region) {
            if (!tensor.has_field("label_id"))
                continue;
            detection.label = tensor.get_int("label_id");
            break;
        }
        detections.push_back(detection);
    }

    ++frame_num;
    tracker->Process(cv::Size(info->width, info->height), detections, frame_num);
    detections = tracker->TrackedDetections();

    // Copy track ids back onto the ROIs they came from.
    for (const TrackedObject &tracked : detections) {
        const bool index_in_range = tracked.object_index >= 0 && tracked.object_index < regions.NumberObjects();
        if (index_in_range && tracked.object_id >= 0)
            regions[tracked.object_index].meta()->id = tracked.object_id + 1; // tracked objects starting id=1
    }
}

/**
 * Runs the enabled stages on one buffer: gallery identification (if a gallery was loaded),
 * then tracking (if a tracker was created).
 */
void Identify::ProcessOutput(GstBuffer *buffer, GstVideoInfo *info) {
    // Identification runs first so TrackObjects can read the "label_id" fields it writes.
    if (gallery != nullptr)
        IdentifyObjects(buffer);
    if (tracker != nullptr)
        TrackObjects(buffer, info);
}

/**
 * Processes one frame with the element's Identify instance.
 * Expects ovino->identifier to be non-null. Always returns GST_FLOW_OK.
 */
GstFlowReturn frame_to_identify(GstGvaIdentify *ovino, GstBuffer *buf, GstVideoInfo *info) {
    Identify *const identifier = ovino->identifier;
    identifier->ProcessOutput(buf, info);
    return GST_FLOW_OK;
}

// Heap-allocates an Identify for `ovino`; the caller owns it (pair with identifier_delete()).
Identify *identifier_new(GstGvaIdentify *ovino) {
    Identify *created = new Identify(ovino);
    return created;
}

// Destroys an Identify created by identifier_new(); a null pointer is accepted and ignored.
void identifier_delete(Identify *identifier) {
    if (identifier == nullptr)
        return; // nothing to release
    delete identifier;
}
